

% Reset the workspace, keeping only the two inputs this script expects from
% its caller: diffRadius (distance weight of the diffusion term) and MAP
% (the stacked parameter vector sliced further below).
clearvars -except diffRadius MAP


%% data (including out of sample)
% NOTE(review): the .mat file presumably provides Sigma and the prior
% hyperparameters used later (a0, om_a, th0, om_th, b0A, b0D, om_b, s_v,
% d_v, f0, om_f, s_vs, d_vs) — confirm.
load('../../true_DGP/priorWSigma.mat')
data=readtable('../../../Data/data_main_sample.csv');
dd=readtable('../../../Data/distance_capitals.csv');

% Keep rows from 1950 on with a finite regime indicator D. The year 1950 is
% only used for the initial diffusion term, so the time axis starts after it.
data=data(data.year>=1950&isfinite(data.D),:);
ccodes=unique(data.ccode);
years=unique(data.year(data.year>1950));
N=numel(ccodes);
T=numel(years);

% N-by-T panels: y (outcome), D (regime), M (1 = observed, 0 = missing).
% A cell is observed when the country has a row for that year and the
% row's y is finite; unobserved cells stay 0 in all three panels.
y=zeros(N,T);
D=zeros(N,T);
M=zeros(N,T);
for n=1:N
    dn=data(data.ccode==ccodes(n),:);
    for t=1:T
        if ~any(dn.year==years(t)) || ~all(isfinite(dn.y(dn.year==years(t))))
            continue
        end
        M(n,t)=1; % not missing
        y(n,t)=dn.y(dn.year==years(t));
        D(n,t)=dn.D(dn.year==years(t));
    end
end

% Distance matrix: every column of dd except the first (presumably the
% country identifier — confirm), rescaled by 10^(-6).
% NOTE(review): rows/columns are assumed to follow the order of ccodes.
Z=table2array(dd(:,2:end))*10^(-6);
% number of distance layers stacked along Z's third dimension
KZ=size(Z,3);

% democracy diffusion/contagion
% X(n,t), filled only for observed cells (M(n,t)==1), is a distance-weighted
% average of the OTHER countries' previous-period regime D. The weights are
% exp(-diffRadius*Z(:,n)) restricted to countries observed in the previous
% period, with country n's own weight zeroed, normalized to sum to one.
% For t==1 the previous period is 1950 (nm0 / D0 below).
% NOTE(review): if no other country carries weight, sum(weights) is 0 and
% X(n,t) becomes NaN.
X=zeros(N,T);
% nm0(n) = 1 if country n has a 1950 row with finite D, else 0
nm0=1*arrayfun(@(n)isempty(setdiff(ccodes(n),data.ccode(data.year==1950&isfinite(data.D)))),1:length(ccodes))';
% D0{n}: country n's 1950 D value(s); forced to 0 where nm0==0, then unwrapped.
% NOTE(review): cell2mat assumes at most one 1950 row per country — confirm.
D0=arrayfun(@(n)data.D(data.year==1950&data.ccode==ccodes(n)),1:length(ccodes),'UniformOutput',0)';
D0(nm0==0)={0};
D0=cell2mat(D0);
for n=1:N
    for t=1:T
        if M(n,t)==1
            if t==1
                % weight countries present in 1950 by their distance to n
                weights=nm0.*exp(-diffRadius*Z(:,n));
                weights(n)=0;
                weights=weights/sum(weights);
                X(n,t)=weights'*D0;
            else
                % weight countries observed in t-1 (D is 0 in missing cells)
                weights=M(:,t-1).*exp(-diffRadius*Z(:,n));
                weights(n)=0;
                weights=weights/sum(weights);
                X(n,t)=weights'*D(:,t-1);
            end
        end
    end
end
% Subtract the mean over observed in-sample cells (periods 1..50; X is 0
% wherever M is 0), then re-zero the missing cells.
X=M.*(X-sum(sum(X(:,1:50)))/sum(sum(M(:,1:50))));
% K: number of diffusion regressors (slices of X along dim 3; length of xi)
K=1;
clear nm0 D0 t weights

% shock standard deviations: square roots of the diagonal of Sigma
ep_sd=sqrt(diag(Sigma));

% Sparse-grid integration: 2-D nodes and weights from the external helper
% nwspgr, used later for the belief (column 1) and shock (column 2) draws
[sgi_nodes,sgi_weights]=nwspgr('KPN',2,8);

% Label each country with the first unique idacr value among its rows.
% The name "cnames" must be kept: table() takes column names from it.
cnames=cell(N,1);
for n=1:N
    acr=unique(data.idacr(data.ccode==ccodes(n)));
    cnames{n}=acr{1};
end
ccodes=table(ccodes,cnames);

% drop the raw inputs and the loop temporaries
clear a acr cnames data dd n


%% parameter estimates
% Slice the stacked vector MAP into named parameter blocks, in order:
% N, 2, 2N, N, N, N, K and KZ entries.
alpha=MAP(1:N);                     % N entries
theta=MAP(N+1:N+2);                 % 2 entries
beta=MAP(N+3:3*N+2);                % 2N entries
v=MAP(3*N+3:4*N+2);                 % N entries
f=MAP(4*N+3:5*N+2);                 % N entries
vs=MAP(5*N+3:6*N+2);                % N entries
xi=MAP(6*N+3:6*N+2+K);              % K entries, multiply the diffusion term X
gamma=MAP(6*N+K+3:6*N+K+2+KZ);      % KZ entries, weight the layers of Z


%% log-likelihood
% log_posterior (defined elsewhere) is evaluated at MAP on the in-sample
% periods 1..50. Its negation, minus the log-prior kernels spelled out
% below, is reported as the log-likelihood. The kernels are Gaussian for
% alpha, theta, beta(1:N), beta(N+1:2N) and f, and -(s+1)*log(x)-d/x for
% v and vs.
% NOTE(review): this assumes log_posterior returns the NEGATIVE log
% posterior and uses exactly these prior kernels (constants dropped, no
% prior term for xi or gamma) — confirm against log_posterior.
ll=-log_posterior(MAP,D(:,1:50),M(:,1:50),X(:,1:50,:),Z,a0,om_a,th0,om_th,b0A,b0D,om_b,s_v,d_v,f0,om_f,s_vs,d_vs)-...
    (sum(-0.5*((alpha-a0)/om_a).^2)+sum(-0.5*((theta-th0)/om_th).^2)+...
    sum(-0.5*((beta(1:N)-b0A)/om_b).^2)+sum(-0.5*((beta(N+(1:N))-b0D)/om_b).^2)+...
    sum(-(s_v+1)*log(v)-d_v*v.^(-1))+sum(-0.5*((f-f0)/om_f).^2)+...
    sum(-(s_vs+1)*log(vs)-d_vs*vs.^(-1)));


%% one-step-ahead forecasts
% Combine the distance layers: ZZ = sum_k gamma(k)*Z(:,:,k).
ZZ=zeros(N,N);
for kz=1:KZ
    ZZ=ZZ+Z(:,:,kz)*gamma(kz);
end 
% N-by-N covariance: element-wise exp(-ZZ) scaled on both sides by
% v.*ep_sd, then symmetrized from its lower triangle.
Q=diag(v.*ep_sd)*exp(-ZZ)*diag(v.*ep_sd);
% 2N-by-2N block-diagonal covariance (the same Q for both halves of the
% belief vector), inverted to a precision matrix, with one copy per period.
P=kron(eye(2),tril(Q)+tril(Q,-1)');
P=P\eye(2*N);
P=repmat(P,1,1,T);

% XX(:,t): diffusion term X(:,t,:)*xi.
% Pd(:,t): belief standard deviations, sqrt(diag(inv(P(:,:,t)))).
% B(:,t):  belief means, initialized to beta in column 1.
XX=zeros(N,T);
Pd=[sqrt(diag(P(:,:,1)\eye(2*N))),zeros(2*N,T-1)];
B=[beta,zeros(2*N,T-1)];
% Period t+1 starts from period t's beliefs. Then each country observed in t
% contributes y(n,t). DD=[diag(1-D),diag(D)] routes the observation to the
% first half of the beliefs when D==0 and to the second half when D==1
% (assuming D is 0/1 — confirm).
for t=1:T-1
    XX(:,t)=permute(X(:,t,:),[1,3,2])*xi;
    P(:,:,t+1)=P(:,:,t);
    B(:,t+1)=B(:,t);
    nm=find(M(:,t)==1);
    DD=[diag(1-D(nm,t)),diag(D(nm,t))];
    % precision update on the observed rows/columns: + DD'*inv(Sigma(nm,nm))*DD
    P([nm;N+nm],[nm;N+nm],t+1)=P([nm;N+nm],[nm;N+nm],t+1)+DD'*(Sigma(nm,nm)\DD);
    Pd(:,t+1)=sqrt(diag(P(:,:,t+1)\eye(2*N)));
    % Mean update using the already-updated precision.
    % NOTE(review): only the observed entries of B move, and the solve uses
    % the observed sub-block of P rather than the full precision, although Q
    % has off-diagonal terms (Pd above uses the full inverse). Confirm this
    % matches the filter inside log_posterior.
    B([nm;N+nm],t+1)=B([nm;N+nm],t+1)+P([nm;N+nm],[nm;N+nm],t+1)\(DD'*(Sigma(nm,nm)\(y(nm,t)-DD*B([nm;N+nm],t+1))));
end
XX(:,T)=permute(X(:,T,:),[1,3,2])*xi;

% Sparse-grid quadrature of the two option values for every (n,t) cell
% (columns of U0/U1 are grid nodes, rows are the N*T cells in column order).
% Node column 1 scales the belief spread Pd; node column 2 scales ep_sd.
%   U0 = exp(alpha + theta(1)*(draw around B(1:N,:)))
%   U1 = exp(alpha + theta(2)*(draw around B(N+1:2N,:)) - f - XX)
NS=length(sgi_weights);
U0=zeros(N*T,NS);
U1=zeros(N*T,NS);

for ns=1:NS
	U0(:,ns)=reshape(exp(alpha+theta(1)*(sgi_nodes(ns,1)*Pd(1:N,:)+B(1:N,:)+sgi_nodes(ns,2)*ep_sd)),N*T,1);
    U1(:,ns)=reshape(exp(alpha+theta(2)*(sgi_nodes(ns,1)*Pd(N+(1:N),:)+B(N+(1:N),:)+sgi_nodes(ns,2)*ep_sd)-f-XX),N*T,1);
end
% Weighted sum of exp(u)/(1+exp(u)) over the nodes; cells with M==0 are zeroed.
U0=reshape(M,N*T,1).*((U0./(1+U0))*sgi_weights);
U1=reshape(M,N*T,1).*((U1./(1+U1))*sgi_weights);

% Predicted choice is 1 where U1 >= U0. Unobserved cells have U0 == U1 == 0
% and so get D_hat == 1; the statistics computed from D_hat mask by M.
D_hat=1*reshape(U1>=U0,N,T);


%% observed and predicted transitions into or out of democracy
% trns(n,t)   = D(n,t)-D(n,t-1)
% trns_h(n,t) = D_hat(n,t)-D_hat(n,t-1)
% for every country observed in both t-1 and t. All other cells, including
% column 1, stay 0.
trns=zeros(N,T);
trns_h=zeros(N,T);
for t=2:T
    obs=find(M(:,t-1)==1 & M(:,t)==1);
    trns(obs,t)=D(obs,t)-D(obs,t-1);
    trns_h(obs,t)=D_hat(obs,t)-D_hat(obs,t-1);
end
clear obs

% correct predictions at exact and 5-year windows
% For each observed in-sample transition (trns(n,t)~=0, t = 2..50):
%   correct0 counts it if the predicted transition in the same period
%            equals it;
%   correct5 counts it if any predicted transition in periods t-2..t+2
%            equals it. The window is clipped to columns 2..T so it never
%            indexes past the last period of the panel (previously t+2 could
%            exceed T and raise an index error).
correct0=0; correct5=0;
for n=1:N
    for t=2:50
        if trns(n,t)~=0
            correct0=correct0+1*(trns_h(n,t)==trns(n,t));
            correct5=correct5+1*any(trns_h(n,max(2,t-2):min(T,t+2))==trns(n,t));
        end
    end
end


%% print results
% All statistics use the in-sample periods 1..50. The "% correctly
% predicted" lines print ratios of counts (not multiplied by 100).
disp(['for Table A5: learning model with direct diffusion and distance weight = ',num2str(diffRadius)])
disp(' ')
nObs=sum(sum(M(:,1:50)));
disp(['observations = ',num2str(nObs)])
disp(['log-likelihood = ',num2str(ll)])
nHit=sum(sum(M(:,1:50).*(D(:,1:50)==D_hat(:,1:50))));
disp(['% correctly predicted choices = ',num2str(nHit/nObs)])
nTrns=sum(sum(abs(trns(:,1:50))));
disp(['% correctly predicted transitions within 0 years = ',num2str(correct0/nTrns)])
disp(['% correctly predicted transitions within 5 years = ',num2str(correct5/nTrns)])
disp(' ')
disp(' ')



